import torch.utils.data
import torchvision

from .SHA import build as build_sha
from .CC import build as build_cc

# Root directory for each supported dataset, keyed by the value expected in
# ``args.dataset_file``. These are absolute paths on the training machine;
# edit them when running elsewhere.
# NOTE(review): the 'CC' key points at a processed 'dronergbt' directory —
# confirm the key name matches the dataset actually stored there.
data_path = dict(
    SHA='/root/autodl-fs/hc/datasets/ShanghaiTech/part_A',
    CC='/root/autodl-fs/hc/processed_dataset/dronergbt',
)

def build_dataset(image_set, args):
    """Build the dataset selected by ``args.dataset_file``.

    Looks up the dataset's root directory in the module-level ``data_path``
    mapping, stores it on ``args.data_path``, and delegates construction to
    the matching dataset-specific ``build`` function.

    Args:
        image_set: Split identifier forwarded unchanged to the dataset
            builder (presumably e.g. 'train' / 'val' — confirm with the
            builders).
        args: Namespace-like config object. ``args.dataset_file`` selects the
            dataset; ``args.data_path`` is overwritten with its root path.

    Returns:
        Whatever the selected dataset builder returns.

    Raises:
        ValueError: if ``args.dataset_file`` is not a supported dataset.
    """
    builders = {
        'SHA': build_sha,
        'CC': build_cc,
    }
    dataset = args.dataset_file
    # Validate before touching ``data_path`` or ``args``. Indexing the path
    # mapping first would surface an opaque KeyError instead of this error.
    if dataset not in builders or dataset not in data_path:
        raise ValueError(f'dataset {dataset} not supported')
    args.data_path = data_path[dataset]
    return builders[dataset](image_set, args)
